{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data generation notebook\n",
    "Used for inspection of the synthetic dataset and tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/arthur/.conda/envs/py38_cu11_2/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D #for plotting the 3-D plot.\n",
    "from src.data_generator.labels_generator import Label_generator\n",
    "from src.data_generator.data_gen_utils import *\n",
    "from src.motion_refiner_4D import Motion_refiner\n",
    "from src.functions import *\n",
    "from src.config import *\n",
    "\n",
    "import random\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=DeprecationWarning) \n",
    "# warnings.filterwarnings(\"ignore\", category=FutureWarning) \n",
    "\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading BERT model... "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_projector.weight']\n",
      "- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "done\n",
      "loading CLIP model... done\n",
      "loading precomputed CLIP text embbedings... done\n",
      "loading precomputed CLIP img embbedings... done\n",
      "DEVICE:  cuda\n"
     ]
    }
   ],
   "source": [
    "mr = Motion_refiner(traj_n = traj_n, load_models = True, clip_only=False, locality_factor=True, poses_on_features=True, load_precomp_emb=True)\n",
    "obj_lib_file= image_dataset_folder+\"imagenet1000_clsidx_to_labels.txt\"\n",
    "dataset_name = \"test\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cartesian, distance and speed commands"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00,  2.04it/s]\n",
      "100%|██████████| 2/2 [00:01<00:00,  1.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8\n"
     ]
    }
   ],
   "source": [
    "\n",
    "dg = data_generator({'dist':1, 'cartesian':1, 'speed':1}, obj_lib_file= obj_lib_file, images_base_path=image_dataset_folder)\n",
    "data = dg.generate(2,4,N=[20,100],n_int=[5,15])\n",
    "n_map = 2\n",
    "labels_per_map = 4\n",
    "data = dg.generate(n_map,labels_per_map,N=[50,100],n_int=[3,15])\n",
    "print(len(data))\n",
    "\n",
    "## ------- processed data -------\n",
    "# X,Y = mr.prepare_data(data,deltas=False)\n",
    "# print(\"X: \",X.shape)\n",
    "# print(\"Y: \",Y.shape)\n",
    "# print(\"DONE computing embeddings\")\n",
    "# print(\"saving data...\")\n",
    "# ------- save pre processed data -------\n",
    "# mr.save_XY(X, Y, x_name=\"X\"+dataset_name,y_name=\"Y\"+dataset_name)\n",
    "# mr.save_data(data,data_name=\"data\"+dataset_name)\n",
    "# print(\"DONE\")\n",
    "%matplotlib qt\n",
    "\n",
    "data_sample = random.choices(data,k=3)\n",
    "show_data4D(data_sample)#,image_loader=mr.image_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Force commands"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  1.30it/s]\n",
      "100%|██████████| 10/10 [00:01<00:00,  5.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10\n"
     ]
    }
   ],
   "source": [
    "dg = data_generator({'force':1,'cartesian force':1}, obj_lib_file= obj_lib_file, images_base_path=image_dataset_folder)\n",
    "data = dg.generate(1,3,N=[20,100],n_int=[5,15])\n",
    "n_map = 10\n",
    "labels_per_map = 1\n",
    "data = dg.generate(n_map,labels_per_map,N=[50,100],n_int=[3,15],output_forces=True)\n",
    "print(len(data))\n",
    "\n",
    "# ------- processed data -------\n",
    "# X,Y = mr.prepare_data(data,deltas=False,output_forces=True, change_img_base=['/home/arthur/image_dataset/','/home/arthur/data/image_dataset//'])\n",
    "# print(\"X: \",X.shape)\n",
    "# print(\"Y: \",Y.shape)\n",
    "# print(\"DONE computing embeddings\")\n",
    "# print(\"saving data...\")\n",
    "# # ------- save pre processed data -------\n",
    "# mr.save_XY(X, Y, x_name=\"X\"+dataset_name,y_name=\"Y\"+dataset_name)\n",
    "# mr.save_data(data,data_name=\"data\"+dataset_name)\n",
    "# print(\"DONE\")\n",
    "%matplotlib qt\n",
    "\n",
    "# data_sample = random.choices(data,k=3)\n",
    "data_sample = data\n",
    "\n",
    "show_data4D(data_sample,plot_output=False)#,image_loader=mr.image_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading dataset:  forces_only_f ...done\n",
      "raw X: (100, 953) \tY: (100, 160)\n",
      "filtered X: (100, 953) \tY: (100, 160)\n"
     ]
    }
   ],
   "source": [
    "X,Y, data = mr.load_dataset(\"forces_only_f\", filter_data = True, base_path=data_folder)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('py38_cu11_2')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "5de91200c9f9e1f8a0c28ceba668014be0fd55838e84400e0a7ad1d269192773"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
